# Load the salmonid phylogeny and the expression table (one row per gene,
# columns named by species).
tree <- read.tree(
  "data/comparison/newFormat/salmonData/salmonidsPhylo.newick"
)
gene.data <- getExprData(
  "data/comparison/newFormat/salmonData/salmon20.tsv"
)

# Tip labels of the clade whose edges are flagged as theta-shift edges in the
# fits below.
shiftSpecies <- c("Salp", "Okis", "Omyk", "Ssal")

# Per-gene starting values handed to the optimiser (indexed by gene ID below).
initParams <- initialParamsTwoTheta(
  gene.data,
  colSpecies = colnames(gene.data),
  shiftSpecies = shiftSpecies
)

# Box constraints for the optimiser, given on the untransformed scale.
lowerBound <- c(
  theta1 = -99, theta2 = -99, sigma2 = 0.0001, alpha = 0.001, beta = 0.001
)
upperBound <- c(
  theta1 = 99, theta2 = 99, sigma2 = 9999, alpha = 999, beta = 99
)

# Parameter transforms, keyed by name. Each entry is a pair of functions:
#   trans   : untransformed parameters -> the scale the optimiser works on
#   untrans : optimiser scale -> untransformed parameters
# Parameters are looked up by name, so the input must carry the names used
# below; the output is always a named numeric vector in a fixed order.
parTransFuns <- list(
  # Identity: optimise on the original scale.
  none = list(
    trans = function(par) par,
    untrans = function(par) par
  ),
  # sigma2, alpha and beta on the log scale; the thetas are left untouched.
  log = list(
    trans = function(par) {
      c(
        theta1 = par[["theta1"]],
        theta2 = par[["theta2"]],
        sigma2 = log(par[["sigma2"]]),
        alpha = log(par[["alpha"]]),
        beta = log(par[["beta"]])
      )
    },
    untrans = function(par) {
      c(
        theta1 = par[["theta1"]],
        theta2 = par[["theta2"]],
        sigma2 = exp(par[["sigma2"]]),
        alpha = exp(par[["alpha"]]),
        beta = exp(par[["beta"]])
      )
    }
  ),
  # Replace (sigma2, beta) by rho = sigma2 / (2 * alpha) and tau = beta * rho.
  # Note the element order on the transformed scale: alpha comes third.
  rho = list(
    trans = function(par) {
      halfRatio <- par[["sigma2"]] / (2 * par[["alpha"]])
      c(
        theta1 = par[["theta1"]],
        theta2 = par[["theta2"]],
        alpha = par[["alpha"]],
        rho = halfRatio,
        tau = par[["beta"]] * par[["sigma2"]] / (2 * par[["alpha"]])
      )
    },
    untrans = function(par) {
      c(
        theta1 = par[["theta1"]],
        theta2 = par[["theta2"]],
        sigma2 = 2 * par[["rho"]] * par[["alpha"]],
        alpha = par[["alpha"]],
        beta = par[["tau"]] / par[["rho"]]
      )
    }
  )
)

#' Fit the two-theta model to one gene under a chosen parameter transform
#'
#' Minimises the negative of `logLikTwoTheta()` with `optim()` using the
#' "L-BFGS-B" method. Starting values and bounds are mapped onto the
#' optimiser scale with the selected entry of `parTransFuns`, and the fitted
#' parameters are mapped back before returning.
#'
#' @param tree Phylogeny; accessed through `Nedge()`, `getMRCA()`,
#'   `getEdgesFromMRCA()`, `tree$edge` and `tree$tip.label`.
#' @param shiftSpecies Tip labels defining the clade whose edges are flagged
#'   as theta-shift edges.
#' @param colSpecies Species label for each expression column; matched
#'   against `tree$tip.label`.
#' @param gene.data.row Expression values for one gene, passed unchanged to
#'   `logLikTwoTheta()`.
#' @param initParams Starting values on the untransformed scale. The
#'   likelihood reads the untransformed parameters by position
#'   (theta1, theta2, sigma2, alpha, beta), so with `trans = "none"` the
#'   elements must already be in that order.
#' @param lowerBound,upperBound Box constraints on the untransformed scale.
#' @param trans Name of the transform in `parTransFuns`:
#'   `"none"`, `"log"` or `"rho"`.
#' @param shiftOnStartOfBranch If `TRUE`, the edge leading into the MRCA of
#'   `shiftSpecies` is also flagged as a theta-shift edge.
#' @return The list returned by `optim()`, with `par` back-transformed to the
#'   untransformed scale. `value` is the minimised negative log-likelihood.
fitTwoThetasTrans <- function(tree, shiftSpecies, colSpecies, gene.data.row,
                              initParams, lowerBound, upperBound,
                              trans = c("none", "log", "rho"),
                              shiftOnStartOfBranch = FALSE) {

  transFun <- parTransFuns[[match.arg(trans)]]

  isThetaShiftEdge <-
    seq_len(Nedge(tree)) %in% getEdgesFromMRCA(tree, tips = shiftSpecies)

  if (shiftOnStartOfBranch) {
    # The stem edge is the one whose child node is the clade's MRCA.
    stemEdge <- match(getMRCA(tree, shiftSpecies), tree$edge[, 2])
    if (length(stemEdge) != 1L || is.na(stemEdge)) {
      # Previously this case was a silent no-op, making the "early" fit
      # indistinguishable from the "late" one without any notice.
      warning(
        "shiftOnStartOfBranch ignored: no edge leads into the MRCA of ",
        "shiftSpecies",
        call. = FALSE
      )
    } else {
      isThetaShiftEdge[stemEdge] <- TRUE
    }
  }

  # match the column species with the phylogeny tip labels
  index.expand <- match(colSpecies, tree$tip.label)

  # Objective: negative log-likelihood as a function of the transformed
  # parameters (optim minimises).
  LLPerGeneTwoTheta <- function(par) {
    params <- transFun$untrans(par)
    ll <- logLikTwoTheta(
      theta1 = params[1], theta2 = params[2], sigma2 = params[3],
      alpha = params[4], beta = params[5],
      tree = tree, isThetaShiftEdge = isThetaShiftEdge,
      gene.data.row = gene.data.row, index.expand = index.expand
    )
    -ll
  }

  # This doesn't seem to work..
  # if( trans == "rho"){
  #   # rho transform uses inverse of alpha, therefore swap upper and lower bound of alpha
  #   tmp <- lowerBound["alpha"]
  #   lowerBound["alpha"] <- upperBound["alpha"]
  #   upperBound["alpha"] <- tmp
  # }

  res <- optim(
    par = transFun$trans(initParams),
    fn = LLPerGeneTwoTheta,
    method = "L-BFGS-B",
    lower = transFun$trans(lowerBound),
    upper = transFun$trans(upperBound)
  )

  # Report the estimates on the untransformed scale.
  res$par <- transFun$untrans(res$par)

  res
}
# Fit every gene once per parameter transform. The outer list is named by
# transform ("log", "none"); each inner list follows rownames(gene.data).
resAll <- lapply(
  setNames(nm = c("log", "none")),
  function(trans) {
    lapply(
      rownames(gene.data),
      function(geneID) {
        fitTwoThetasTrans(
          tree = tree,
          shiftSpecies = shiftSpecies,
          colSpecies = colnames(gene.data),
          gene.data.row = gene.data[geneID, ],
          initParams = initParams[geneID, ],
          trans = trans,
          lowerBound = lowerBound,
          upperBound = upperBound
        )
      }
    )
  }
)

# Collect the fits into one long table: one row per gene and transform, with
# the parameter estimates, log-likelihood, optim convergence code and the
# number of objective-function evaluations.
resTbl <-
  lapply(resAll, function(resTrans) {
    # rbind keeps a genes x parameters matrix even for a single gene.
    do.call(rbind, lapply(resTrans, function(res) res$par)) %>%
      as_tibble() %>%
      mutate(
        geneID = rownames(gene.data),
        # optim minimised the negative log-likelihood, so flip the sign back.
        ll = vapply(resTrans, function(res) -res$value, numeric(1)),
        convergence = map_int(resTrans, "convergence"),
        # First element of optim's counts = number of calls to fn.
        iterations = vapply(
          resTrans,
          function(res) as.integer(res$counts[[1]]),
          integer(1)
        )
      )
  }) %>%
  bind_rows(.id = "trans")

log transform and bounds

effect on iterations

# Number of objective evaluations per gene and transform. The bar outline
# marks convergence: black = optim reported code 0, green = any other code.
resTbl %>%
  mutate(
    geneID = factor(geneID, levels = rev(unique(geneID))),
    convergence = convergence == 0
  ) %>%
  ggplot(aes(x = geneID, y = iterations, fill = trans, color = convergence)) +
  geom_col(position = "dodge", width = 0.7) +
  # Named values: unnamed ones are assigned in order of the levels present,
  # so with all fits converged TRUE would have been drawn in green.
  scale_color_manual(values = c("FALSE" = "green", "TRUE" = "black")) +
  coord_flip()

parameter estimates

# Estimates of sigma2, alpha and beta per gene for both transforms, coloured
# by each fit's log-likelihood relative to the gene's mean over transforms.
# Dotted red lines mark the upper bounds.
resTbl %>%
  select(-theta1, -theta2, -iterations) %>%
  group_by(geneID) %>%
  mutate(llDiff = ll - mean(ll)) %>%
  ungroup() %>%
  select(-ll) %>%
  gather(
    key = "Parameter", value = "Value",
    -geneID, -trans, -convergence, -llDiff
  ) %>%
  mutate(
    geneID = factor(geneID, levels = rev(unique(geneID))),
    # Flags for estimates sitting (within 0.1 %) on a bound.
    atUpperBound = Value > (upperBound[Parameter] * 0.999),
    atLowerBound = Value < (lowerBound[Parameter] * 1.001),
    atBounds = atUpperBound | atLowerBound
  ) %>%
  ggplot(aes(x = Value, y = geneID, color = llDiff, shape = trans)) +
  geom_point(size = 3) +
  scale_shape_manual(values = c(3, 4)) +
  scale_color_gradient2(low = "blue", mid = "black", high = "red") +
  facet_grid(. ~ Parameter, scales = "free") +
  theme(legend.position = "bottom") +
  scale_x_log10() +
  geom_vline(
    data = tibble(x = upperBound[3:5], Parameter = names(upperBound)[3:5]),
    aes(xintercept = x),
    color = "red", linetype = 3
  )

overfitting alpha to the mean?

Expected expression

# One plot per gene: expression (y) against time since the theta shift (x),
# with the expected-expression trend line from the "log" fit.

# Distance along the tree from the MRCA of the shift clade to each shift
# species.
mrcaNode <- getMRCA(tree, shiftSpecies)
spcNodes <- match(shiftSpecies, tree$tip.label)
timeSinceShift <- setNames(dist.nodes(tree)[mrcaNode, spcNodes], shiftSpecies)

for (gene in rownames(gene.data)) {

  cat("\n\n####",gene,"\n")

  g <- local({
    fit <- filter(resTbl, geneID == gene, trans == "log")

    # Expected level moving from theta1 towards theta2 at rate alpha.
    trendLine <- tibble(
      timeSinceShift = seq(0, max(timeSinceShift), length.out = 100)
    ) %>%
      mutate(
        expLevel = fit$theta2 +
          (fit$theta1 - fit$theta2) * exp(-fit$alpha * timeSinceShift)
      )

    # Observed values; species outside the shift clade are placed at time 0.
    observed <- tibble(
      expLevel = gene.data[gene, ],
      spc = colnames(gene.data)
    ) %>%
      mutate(
        timeSinceShift = ifelse(spc %in% shiftSpecies, timeSinceShift[spc], 0)
      )

    plotTitle <- paste0(
      gene,
      "  sigma2 = ", signif(fit$sigma2, 3),
      "  alpha = ", signif(fit$alpha, 3),
      "  beta = ", signif(fit$beta, 3)
    )

    ggplot(observed, aes(x = timeSinceShift, y = expLevel)) +
      geom_point() +
      geom_line(data = trendLine, color = "red", linetype = 2) +
      geom_hline(yintercept = fit$theta1, color = "blue", linetype = 3) +
      annotate("text", label = "theta1", x = 0.02, y = fit$theta1,
               color = "blue", vjust = -0.5) +
      geom_hline(yintercept = fit$theta2, color = "blue", linetype = 3) +
      annotate("text", label = "theta2", x = 0.02, y = fit$theta2,
               color = "blue", vjust = -0.5) +
      ggtitle(plotTitle)
  })
  print(g)

}

OG0000052_3

OG0000053_2

OG0000081_1

OG0000082_1

OG0000129_1

OG0000139_1

OG0000162_1

OG0000178_1

OG0000178_2

OG0000193_2

OG0000216_1

OG0000272_1

OG0000290_1

OG0000292_1

OG0000302_2

OG0000307_2

OG0000322_1

OG0000322_2

OG0000352_1

OG0000375_1

Alpha and Ssal expression

Because of the phylogeny, where Ssal has a significantly shorter branch than the other species, and the decision to define the theta shift at the most recent common ancestor of the salmonids, the estimate of alpha may be driven largely by the Ssal expression level.

# Column indices of gene.data grouped by species label (the column names).
colSpeciesIndices <- split(seq_len(ncol(gene.data)), f = colnames(gene.data))

# Mean expression per gene within each species (genes x species).
# drop = FALSE keeps a matrix when a species has a single column; without it
# the subset collapses to a vector and rowMeans() fails.
species.mean <- sapply(colSpeciesIndices, function(i) {
  rowMeans(gene.data[, i, drop = FALSE])
})

# Where the Ssal mean sits between the pre-shift and post-shift species means
# (0 = pre-shift level, 1 = post-shift level), against alpha from the "log"
# fits.
g <- tibble(
  geneID = rownames(gene.data),
  preShift = rowMeans(species.mean[, c("Drer", "Eluc", "Olat")]),
  postShift = rowMeans(species.mean[, c("Salp", "Okis", "Omyk")]),
  Ssal = species.mean[, "Ssal"]
) %>%
  mutate(relSsalShift = (Ssal - preShift) / (postShift - preShift)) %>%
  left_join(
    resTbl %>% filter(trans == "log") %>% select(geneID, alpha),
    by = "geneID"
  ) %>%
  ggplot(aes(x = relSsalShift, y = alpha, color = alpha)) +
  scale_color_gradientn(
    colours = c("blue", "black", "red"),
    trans = "log", breaks = c(10^(-1:3))
  ) +
  geom_point_interactive(aes(tooltip = geneID, data_id = geneID)) +
  scale_y_log10()

girafe(code = print(g))

Between-/Within-species variance

# Expected mean expression per gene and species under the "log" fits:
#   theta2 + (theta1 - theta2) * exp(-alpha * dt)
# where dt is the species' own time since the shift, or 0 for species
# outside the shift clade.
#
# The earlier loop used ifelse() with a scalar test, which returns only the
# first element of timeSinceShift, so every shift species was given the same
# time. Here each species is looked up by name.
expLevel <- local({
  fits <- filter(resTbl, trans == "log")
  # Align the fits with the row order of species.mean.
  fits <- fits[match(rownames(species.mean), fits$geneID), ]

  dt <- setNames(numeric(ncol(species.mean)), colnames(species.mean))
  shifted <- intersect(names(dt), names(timeSinceShift))
  dt[shifted] <- timeSinceShift[shifted]

  # outer() gives genes x species; the per-gene thetas recycle down columns.
  expected <- fits$theta2 +
    (fits$theta1 - fits$theta2) * exp(-outer(fits$alpha, dt))
  dimnames(expected) <- dimnames(species.mean)
  expected
})

# subtract the expected mean before calculating between species mean
species.mean.adjusted <- species.mean - expLevel

# Variance across species of the trend-adjusted species means, against
# sigma2 / (2 * alpha) from the "log" fits. The dashed red line is y = x.
g <- tibble(
  geneID = rownames(gene.data),
  betweenSpeciesVar = apply(species.mean.adjusted, 1, var)
) %>%
  left_join(filter(resTbl, trans == "log"), by = "geneID") %>%
  ggplot(
    aes(x = betweenSpeciesVar, y = sigma2 / (2 * alpha), color = alpha)
  ) +
  scale_color_gradientn(
    colours = c("blue", "black", "red"),
    trans = "log", breaks = c(10^(-1:3))
  ) +
  geom_point() +
  geom_point_interactive(aes(tooltip = geneID, data_id = geneID)) +
  geom_abline(slope = 1, intercept = 0, color = "red", linetype = 2) +
  scale_x_log10() +
  scale_y_log10()

girafe(code = print(g))
# Mean over species of the within-species variance, against
# beta * sigma2 / (2 * alpha) from the "log" fits. The dashed red line is
# y = x.
g <- tibble(
  geneID = rownames(gene.data),
  withinSpeciesVar = rowMeans(
    sapply(colSpeciesIndices, function(i) {
      apply(gene.data[, i], 1, var)
    })
  )
) %>%
  left_join(filter(resTbl, trans == "log"), by = "geneID") %>%
  ggplot(
    aes(x = withinSpeciesVar, y = beta * sigma2 / (2 * alpha), color = alpha)
  ) +
  scale_color_gradientn(
    colours = c("blue", "black", "red"),
    trans = "log", breaks = c(10^(-1:3))
  ) +
  geom_point() +
  geom_point_interactive(aes(tooltip = geneID, data_id = geneID)) +
  geom_abline(slope = 1, intercept = 0, color = "red", linetype = 2) +
  scale_x_log10() +
  scale_y_log10()

girafe(code = print(g))

Late vs early theta-shift

In the model used above, the theta-shift was placed at the point of the first diverging species in the salmonid clade. However, the WGD event, which is where we expect the shift, occurred earlier, at some point along the branch leading to the salmonids. As we have seen above, the relatively short Ssal branch leads to overfitting to the Ssal expression. By moving the theta-shift to the start of the salmonid branch there is less difference in the time since shift and there should be less overfitting.

# Draw the tree twice with the theta-shift edges in orange: first with the
# shift at the MRCA of the shift clade, then with the edge leading into that
# MRCA included as well.
orangeEdges <-
  seq_len(Nedge(tree)) %in% getEdgesFromMRCA(tree, tips = shiftSpecies)

# Tip colours are the same in both plots.
tipColors <- ifelse(tree$tip.label %in% shiftSpecies, "orange", "black")

plot.phylo(tree,
           edge.color = ifelse(orangeEdges, "orange", "black"),
           tip.color = tipColors,
           main = "Shift at root of clade")

# Add the edge whose child node is the clade's MRCA.
orangeEdges[match(getMRCA(tree, shiftSpecies), tree$edge[, 2])] <- TRUE

plot.phylo(tree,
           edge.color = ifelse(orangeEdges, "orange", "black"),
           tip.color = tipColors,
           main = "Shift at start of branch to clade")

# Refit every gene with the "log" transform, this time also flagging the edge
# that leads into the shift clade. Results follow rownames(gene.data).
resEarlyShift <- lapply(
  rownames(gene.data),
  function(geneID) {
    fitTwoThetasTrans(
      tree = tree,
      shiftSpecies = shiftSpecies,
      colSpecies = colnames(gene.data),
      gene.data.row = gene.data[geneID, ],
      initParams = initParams[geneID, ],
      trans = "log",
      lowerBound = lowerBound,
      upperBound = upperBound,
      shiftOnStartOfBranch = TRUE
    )
  }
)


# Early-shift fits as a table: one row per gene with the parameter estimates,
# log-likelihood, optim convergence code and number of objective evaluations.
resTblEarly <-
  # rbind keeps a genes x parameters matrix even for a single gene.
  do.call(rbind, lapply(resEarlyShift, function(res) res$par)) %>%
  as_tibble() %>%
  mutate(
    geneID = rownames(gene.data),
    # optim minimised the negative log-likelihood, so flip the sign back.
    ll = vapply(resEarlyShift, function(res) -res$value, numeric(1)),
    convergence = map_int(resEarlyShift, "convergence"),
    # First element of optim's counts = number of calls to fn.
    iterations = vapply(
      resEarlyShift,
      function(res) as.integer(res$counts[[1]]),
      integer(1)
    )
  )

Compare estimated parameters

# sigma2, alpha and beta per gene for the early- and late-shift fits (both
# with the "log" transform), coloured by each fit's log-likelihood relative to
# the gene's mean over the two. Dotted red lines mark the upper bounds.
list(
  early = resTblEarly,
  late = filter(resTbl, trans == "log")
) %>%
  bind_rows(.id = "shiftPoint") %>%
  select(-theta1, -theta2, -iterations, -trans, -convergence) %>%
  group_by(geneID) %>%
  mutate(llDiff = ll - mean(ll)) %>%
  ungroup() %>%
  select(-ll) %>%
  gather(key = "Parameter", value = "Value", sigma2, alpha, beta) %>%
  mutate(geneID = factor(geneID, levels = rev(unique(geneID)))) %>%
  ggplot(aes(x = Value, y = geneID, color = llDiff, shape = shiftPoint)) +
  geom_point(size = 3) +
  scale_shape_manual(values = c(3, 4)) +
  scale_color_gradient2(low = "blue", mid = "black", high = "red") +
  facet_grid(. ~ Parameter, scales = "free") +
  theme(legend.position = "bottom") +
  scale_x_log10() +
  geom_vline(
    data = tibble(x = upperBound[3:5], Parameter = names(upperBound)[3:5]),
    aes(xintercept = x),
    color = "red", linetype = 3
  )

# Per-gene difference in log-likelihood between the early-shift and the
# late-shift ("log") fits. Red bars: early shift fits better; blue: worse or
# equal.
resTblEarly %>%
  transmute(llEarly = ll, geneID = geneID) %>%
  left_join(
    resTbl %>% filter(trans == "log") %>% select(ll, geneID),
    by = "geneID"
  ) %>%
  mutate(
    llDiff = llEarly - ll,
    geneID = factor(geneID, levels = rev(unique(geneID)))
  ) %>%
  ggplot(aes(y = llDiff, x = geneID, fill = llDiff > 0)) +
  geom_col() +
  # Named values keep the colour meaning fixed even when only one of
  # TRUE/FALSE occurs; guide = "none" replaces the deprecated guide = F.
  scale_fill_manual(
    values = c("FALSE" = "blue", "TRUE" = "red"),
    guide = "none"
  ) +
  coord_flip() +
  ylab("logL_early - logL_late")

The likelihoods, with few exceptions, are lower when using the early shift. Is this because the late shift model reflects the true process better or is it because the late shift phylogeny allows more overfitting?

Expected expression

# One plot per gene: expression (y) against time since the theta shift (x),
# with the expected-expression trend line from the early-shift fit.

# The early shift sits at the start of the edge leading into the shift clade,
# i.e. at the parent of the clade's MRCA. Taking the parent from tree$edge
# matches the edge flagged by fitTwoThetasTrans(shiftOnStartOfBranch = TRUE)
# and does not depend on naming a particular outgroup species.
earlyNode <- tree$edge[match(getMRCA(tree, shiftSpecies), tree$edge[, 2]), 1]
stopifnot(length(earlyNode) == 1L, !is.na(earlyNode))
spcNodes <- match(shiftSpecies, tree$tip.label)
timeSinceShift <- setNames(dist.nodes(tree)[earlyNode, spcNodes], shiftSpecies)

for (gene in rownames(gene.data)) {

  cat("\n\n####",gene,"\n")

  g <- local({
    fit <- filter(resTblEarly, geneID == gene)

    # Expected level moving from theta1 towards theta2 at rate alpha.
    trendLine <- tibble(
      timeSinceShift = seq(0, max(timeSinceShift), length.out = 100)
    ) %>%
      mutate(
        expLevel = fit$theta2 +
          (fit$theta1 - fit$theta2) * exp(-fit$alpha * timeSinceShift)
      )

    # Observed values; species outside the shift clade are placed at time 0.
    observed <- tibble(
      expLevel = gene.data[gene, ],
      spc = colnames(gene.data)
    ) %>%
      mutate(
        timeSinceShift = ifelse(spc %in% shiftSpecies, timeSinceShift[spc], 0)
      )

    plotTitle <- paste0(
      gene,
      "  sigma2 = ", signif(fit$sigma2, 3),
      "  alpha = ", signif(fit$alpha, 3),
      "  beta = ", signif(fit$beta, 3)
    )

    ggplot(observed, aes(x = timeSinceShift, y = expLevel)) +
      geom_point() +
      geom_line(data = trendLine, color = "red", linetype = 2) +
      geom_hline(yintercept = fit$theta1, color = "blue", linetype = 3) +
      annotate("text", label = "theta1", x = 0.02, y = fit$theta1,
               color = "blue", vjust = -0.5) +
      geom_hline(yintercept = fit$theta2, color = "blue", linetype = 3) +
      annotate("text", label = "theta2", x = 0.02, y = fit$theta2,
               color = "blue", vjust = -0.5) +
      ggtitle(plotTitle)
  })
  print(g)

}

OG0000052_3

OG0000053_2

OG0000081_1

OG0000082_1

OG0000129_1

OG0000139_1

OG0000162_1

OG0000178_1

OG0000178_2

OG0000193_2

OG0000216_1

OG0000272_1

OG0000290_1

OG0000292_1

OG0000302_2

OG0000307_2

OG0000322_1

OG0000322_2

OG0000352_1

OG0000375_1

Alpha vs Ssal relative expression

# Column indices of gene.data grouped by species label (the column names).
colSpeciesIndices <- split(seq_len(ncol(gene.data)), f = colnames(gene.data))

# Mean expression per gene within each species (genes x species).
# drop = FALSE keeps a matrix when a species has a single column; without it
# the subset collapses to a vector and rowMeans() fails.
species.mean <- sapply(colSpeciesIndices, function(i) {
  rowMeans(gene.data[, i, drop = FALSE])
})

# Where the Ssal mean sits between the pre-shift and post-shift species means
# (0 = pre-shift level, 1 = post-shift level), against alpha from the
# early-shift fits.
g <- tibble(
  geneID = rownames(gene.data),
  preShift = rowMeans(species.mean[, c("Drer", "Eluc", "Olat")]),
  postShift = rowMeans(species.mean[, c("Salp", "Okis", "Omyk")]),
  Ssal = species.mean[, "Ssal"]
) %>%
  mutate(relSsalShift = (Ssal - preShift) / (postShift - preShift)) %>%
  left_join(resTblEarly %>% select(geneID, alpha), by = "geneID") %>%
  ggplot(aes(x = relSsalShift, y = alpha, color = alpha)) +
  scale_color_gradientn(
    colours = c("blue", "black", "red"),
    trans = "log", breaks = c(10^(-1:3))
  ) +
  geom_point_interactive(aes(tooltip = geneID, data_id = geneID)) +
  scale_y_log10()

girafe(code = print(g))

With the early theta-shift the